import time
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from config import Config
from model import BertForIntentClassification
from utils import *
from trainer import Trainer
from flask import Flask, request, jsonify

app = Flask(__name__)
# Allow non-ASCII (e.g. Chinese) characters to pass through jsonify unescaped.
app.config['JSON_AS_ASCII'] = False

config = Config()
# Load tokenizer and model.
tokenizer = BertTokenizer.from_pretrained(config.pretrained_model_path)
model = BertForIntentClassification(config)

if config.load_model:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only
    # (or different-device) machine instead of raising a deserialization error.
    model.load_state_dict(torch.load(config.model_load_path, map_location=config.device))

print(f"Total parameters:{get_total_params(model)}") # 102M

# Build datasets and data loaders.
if config.do_train:
    train_examples = get_examples(config.train_file_path, "train")
    train_features = get_features(train_examples, tokenizer, config)
    train_dataset = IntentDataset(train_features)
    # Shuffle training batches each epoch to decorrelate gradient updates.
    train_data_loader = DataLoader(train_dataset, batch_size=config.batchsize, shuffle=True)

if config.do_eval:
    eval_examples = get_examples(config.test_file_path, "eval")
    eval_features = get_features(eval_examples, tokenizer, config)
    eval_dataset = IntentDataset(eval_features)
    # Evaluation order should be deterministic — no shuffling.
    eval_data_loader = DataLoader(eval_dataset, batch_size=config.batchsize, shuffle=False)

# Move the model to the configured device (CPU/GPU) before training/inference.
model.to(config.device)
trainer = Trainer(model, config)
if config.do_train:
    print("---------------Train-----------------")
    start_time = time.time()
    trainer.train(train_data_loader)
    end_time = time.time()
    # Message: "training time: {epoch} epochs took {seconds} s".
    print(f"训练时间花费: {config.epoch} 轮花费 {end_time - start_time:.6f} 秒")
    if config.do_save:
        trainer.save_model(config.model_save_dir, save_name="model.1.pt")

if config.do_eval:
    print("----------------Eval------------------")
    trainer.eval(eval_data_loader)

if config.do_predict:
    print("---------------Predict----------------")
    # Sample sentence: "Where can I catch a bus from Hefei to Shanghai?"
    input_text = "从合肥到上海可以到哪坐车？"
    # Measure single-shot prediction latency.
    start_time = time.time()
    trainer.predict(input_text, tokenizer)
    end_time = time.time()
    # Message: "prediction took {seconds} s".
    print(f"预测时间花费: {end_time - start_time:.6f} 秒")

def predict_intent_and_slots(text):
    """Run the shared trainer's intent/slot prediction on *text* and return its result dict."""
    return trainer.predict(text, tokenizer)

@app.route('/predict', methods=['POST', 'GET'])
def predict():
    """Prediction endpoint.

    POST: expects a JSON body ``{"text": "..."}`` and returns
    ``{"result": ...}`` with the trainer's prediction, or a 400 error
    when the body is missing/malformed or has no ``text`` field.
    GET: runs the prediction on a fixed sample sentence (smoke test).
    """
    if request.method == 'POST':
        # silent=True returns None (instead of raising) on absent or
        # malformed JSON, so we can answer with a clean 400.
        data = request.get_json(silent=True)
        if not data or not data.get('text'):
            return jsonify({'error': "missing 'text' field in JSON body"}), 400
        text = data['text']
        print("获取到文本：" + text)
        result = predict_intent_and_slots(text)
        # Return the prediction result to the caller.
        return jsonify({'result': result})

    # GET: fixed sample sentence for a quick health/smoke check.
    text = "从合肥到上海可以到哪坐车？"
    result = predict_intent_and_slots(text)
    return jsonify({'result': result})
# Start the Flask dev server (default: 127.0.0.1:5000) only when this file
# is executed directly — not when it is imported (e.g. by a WSGI server).
if __name__ == '__main__':
    app.run()